Comparing ESM-based models and RNASamba models for predicting coding and noncoding transcripts¶
Keith Cheveralls
March 2024
This notebook documents the visualizations that were used to compare the performance of ESM-based models and RNASamba models trained to predict whether transcripts are coding or noncoding. This was motivated by developing an approach that uses ESM embeddings to identify sORFs for the peptigate pipeline.
The predictions from ESM-based models and RNASamba models on which this notebook depends were generated outside of this notebook. Predictions from ESM-based models were generated using the commands namespaced under the plmutils orf-classification CLI. Predictions from RNASamba models were generated using the script found in the /scripts/rnasamba subdirectory of this repo. The CLI commands that were used are briefly documented in the sections below.
import io
import pathlib
import pandas as pd
import seaborn as sns
import numpy as np
from Bio import SeqIO
import matplotlib.pyplot as plt
from plmutils.models import calc_metrics
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'
Dataset metadata¶
The metadata associated with the 16 species used for these comparisons is included below for completeness. Note that the plots in this notebook label species using the species_id defined in this metadata (rather than the full species name).
metadata_csv_content = """
species_id species_common_name root_url genome_name cdna_endpoint ncrna_endpoint genome_abbreviation
hsap human https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/ Homo_sapiens.GRCh38 cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz GRCh38
scer yeast https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/saccharomyces_cerevisiae/ Saccharomyces_cerevisiae.R64-1-1 cdna/Saccharomyces_cerevisiae.R64-1-1.cdna.all.fa.gz ncrna/Saccharomyces_cerevisiae.R64-1-1.ncrna.fa.gz R64-1-1
cele worm https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/caenorhabditis_elegans/ Caenorhabditis_elegans.WBcel235 cdna/Caenorhabditis_elegans.WBcel235.cdna.all.fa.gz ncrna/Caenorhabditis_elegans.WBcel235.ncrna.fa.gz WBcel235
atha arabadopsis https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/arabidopsis_thaliana/ Arabidopsis_thaliana.TAIR10 cdna/Arabidopsis_thaliana.TAIR10.cdna.all.fa.gz ncrna/Arabidopsis_thaliana.TAIR10.ncrna.fa.gz TAIR10
dmel drosophila https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/drosophila_melanogaster/ Drosophila_melanogaster.BDGP6.46 cdna/Drosophila_melanogaster.BDGP6.46.cdna.all.fa.gz ncrna/Drosophila_melanogaster.BDGP6.46.ncrna.fa.gz BDGP6.46
ddis dictyostelium_discoideum https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/dictyostelium_discoideum/ Dictyostelium_discoideum.dicty_2.7 cdna/Dictyostelium_discoideum.dicty_2.7.cdna.all.fa.gz ncrna/Dictyostelium_discoideum.dicty_2.7.ncrna.fa.gz dicty_2.7
mmus mouse https://ftp.ensembl.org/pub/release-111/fasta/mus_musculus/ Mus_musculus.GRCm39 cdna/Mus_musculus.GRCm39.cdna.all.fa.gz ncrna/Mus_musculus.GRCm39.ncrna.fa.gz GRCm39
drer zebrafish https://ftp.ensembl.org/pub/release-111/fasta/danio_rerio/ Danio_rerio.GRCz11 cdna/Danio_rerio.GRCz11.cdna.all.fa.gz ncrna/Danio_rerio.GRCz11.ncrna.fa.gz GRCz11
ggal chicken https://ftp.ensembl.org/pub/release-111/fasta/gallus_gallus/ Gallus_gallus.bGalGal1.mat.broiler.GRCg7b cdna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.cdna.all.fa.gz ncrna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.ncrna.fa.gz bGalGal1.mat.broiler.GRCg7b
oind rice https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/oryza_indica/ Oryza_indica.ASM465v1 cdna/Oryza_indica.ASM465v1.cdna.all.fa.gz ncrna/Oryza_indica.ASM465v1.ncrna.fa.gz ASM465v1
zmay maize https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/zea_mays/ Zea_mays.Zm-B73-REFERENCE-NAM-5.0 cdna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.cdna.all.fa.gz ncrna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.ncrna.fa.gz Zm-B73-REFERENCE-NAM-5.0
xtro frog https://ftp.ensembl.org/pub/release-111/fasta/xenopus_tropicalis/ Xenopus_tropicalis.UCB_Xtro_10.0 cdna/Xenopus_tropicalis.UCB_Xtro_10.0.cdna.all.fa.gz ncrna/Xenopus_tropicalis.UCB_Xtro_10.0.ncrna.fa.gz UCB_Xtro_10.0
rnor rat https://ftp.ensembl.org/pub/release-111/fasta/rattus_norvegicus/ Rattus_norvegicus.mRatBN7.2 cdna/Rattus_norvegicus.mRatBN7.2.cdna.all.fa.gz ncrna/Rattus_norvegicus.mRatBN7.2.ncrna.fa.gz mRatBN7
amel honeybee https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/apis_mellifera/ Apis_mellifera.Amel_HAv3.1 cdna/Apis_mellifera.Amel_HAv3.1.cdna.all.fa.gz ncrna/Apis_mellifera.Amel_HAv3.1.ncrna.fa.gz Amel_HAv3.1
spom fission_yeast https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/schizosaccharomyces_pombe/ Schizosaccharomyces_pombe.ASM294v2 cdna/Schizosaccharomyces_pombe.ASM294v2.cdna.all.fa.gz ncrna/Schizosaccharomyces_pombe.ASM294v2.ncrna.fa.gz ASM294v2
tthe tetrahymena https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/tetrahymena_thermophila/ Tetrahymena_thermophila.JCVI-TTA1-2.2 cdna/Tetrahymena_thermophila.JCVI-TTA1-2.2.cdna.all.fa.gz ncrna/Tetrahymena_thermophila.JCVI-TTA1-2.2.ncrna.fa.gz JCVI-TTA1-2.2
"""
metadata = pd.read_csv(io.StringIO(metadata_csv_content), sep='\t')
metadata.head()
| species_id | species_common_name | root_url | genome_name | cdna_endpoint | ncrna_endpoint | genome_abbreviation | |
|---|---|---|---|---|---|---|---|
| 0 | hsap | human | https://ftp.ensembl.org/pub/release-111/fasta/... | Homo_sapiens.GRCh38 | cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz | ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz | GRCh38 |
| 1 | scer | yeast | https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi... | Saccharomyces_cerevisiae.R64-1-1 | cdna/Saccharomyces_cerevisiae.R64-1-1.cdna.all... | ncrna/Saccharomyces_cerevisiae.R64-1-1.ncrna.f... | R64-1-1 |
| 2 | cele | worm | https://ftp.ensemblgenomes.ebi.ac.uk/pub/metaz... | Caenorhabditis_elegans.WBcel235 | cdna/Caenorhabditis_elegans.WBcel235.cdna.all.... | ncrna/Caenorhabditis_elegans.WBcel235.ncrna.fa.gz | WBcel235 |
| 3 | atha | arabadopsis | https://ftp.ensemblgenomes.ebi.ac.uk/pub/plant... | Arabidopsis_thaliana.TAIR10 | cdna/Arabidopsis_thaliana.TAIR10.cdna.all.fa.gz | ncrna/Arabidopsis_thaliana.TAIR10.ncrna.fa.gz | TAIR10 |
| 4 | dmel | drosophila | https://ftp.ensemblgenomes.ebi.ac.uk/pub/metaz... | Drosophila_melanogaster.BDGP6.46 | cdna/Drosophila_melanogaster.BDGP6.46.cdna.all... | ncrna/Drosophila_melanogaster.BDGP6.46.ncrna.f... | BDGP6.46 |
Heatmap plotting functions¶
These are functions used later in the notebook to generate heatmap visualizations of the matrices of model performance metrics for all pairs of training and test species.
def plot_heatmap(df, column='accuracy', model_name='unknown', ax=None, **heatmap_kwargs):
    """
    Plot the values in the given column as a square heatmap of training vs test species
    (with training species on the x-axis and test species on the y-axis).

    Note: the "training species" is the species used to train the model and the "test species"
    is the species used to test each trained model.
    """
    df = df.pivot(index='test_species_id', columns='training_species_id', values=column)
    if ax is None:
        plt.figure(figsize=(8, 6))
        ax = plt.gca()
    sns.heatmap(
        df,
        cmap="coolwarm",
        annot=True,
        annot_kws={"size": 6},
        fmt=".1f",
        square=True,
        ax=ax,
        **heatmap_kwargs,
    )
    name = column.replace('_', ' ')
    if name.lower() == 'mcc':
        name = name.upper()
    else:
        name = name[0].upper() + name[1:]
    ax.set_xlabel('Training species')
    ax.set_ylabel('Test species')
    ax.set_title(f'{name} | {model_name}')
def plot_heatmaps(df_left, df_right, column, model_names):
    """
    Plot a row of three heatmaps: one for the left dataframe, one for the right dataframe,
    and a third (the rightmost) for the difference between the two (right minus left).
    """
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18, 5))
    df_merged = pd.merge(df_left, df_right, on=('training_species_id', 'test_species_id'))
    df_merged[column] = df_merged[f'{column}_y'] - df_merged[f'{column}_x']
    plot_heatmap(df_left, column=column, model_name=model_names[0], ax=axs[0])
    plot_heatmap(df_right, column=column, model_name=model_names[1], ax=axs[1])
    plot_heatmap(df_merged, column=column, model_name='difference', ax=axs[2], vmin=-1, vmax=1)
ESM-based model predictions¶
These predictions were generated using the plmutils orf-prediction CLI.
First, download the Ensembl datasets listed in the user-provided metadata CSV file (see above for the file used with this notebook):
plmutils orf-prediction download-data \
output/data/ensembl-dataset-metadata.tsv \
output/data/
Next, construct deduplicated sets of coding and noncoding transcripts. Deduplication is achieved by clustering transcripts by sequence identity and retaining only one representative sequence from each cluster.
plmutils orf-prediction construct-data \
output/data/ensembl-dataset-metadata.tsv \
output/data/ \
--subsample-factor 1
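The clustering itself happens inside `construct-data` and the plots below don't depend on how it is implemented. As an illustration only (not the plmutils implementation, which clusters at a sequence-identity threshold), deduplication at 100% identity reduces to a greedy filter that keeps one representative per identical sequence:

```python
def dedup_exact(records):
    """Keep the first representative of each identical sequence.

    Illustrative sketch only: clustering-based deduplication groups sequences
    at an identity threshold; exact matching is the degenerate 100%-identity case.
    """
    seen = set()
    representatives = []
    for name, seq in records:
        if seq not in seen:
            seen.add(seq)
            representatives.append((name, seq))
    return representatives
```

For example, `dedup_exact([("a", "ATG"), ("b", "ATG"), ("c", "TGA")])` keeps only the first of the two identical sequences.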
Next, find putative ORFs in the coding and noncoding transcripts, retain only the longest putative ORF from each transcript, and generate an embedding of the protein sequence it encodes:
plmutils orf-prediction translate-and-embed \
output/data/processed/final/coding-dedup-ssx1/transcripts
plmutils orf-prediction translate-and-embed \
output/data/processed/final/noncoding-dedup-ssx1/transcripts
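The ORF-calling rules live in plmutils and are not reproduced here; a simplified sketch of what "the longest putative ORF" means (forward strand only, ATG-to-stop, with an in-frame stop codon required) might look like:

```python
STOP_CODONS = {"TAA", "TAG", "TGA"}

def longest_orf(transcript):
    """Return the longest ATG-to-stop ORF across the three forward frames.

    Simplified sketch: forward strand only, and an in-frame stop codon
    is required for an ORF to count.
    """
    best = ""
    for frame in range(3):
        i = frame
        while i + 3 <= len(transcript):
            if transcript[i : i + 3] == "ATG":
                j = i + 3
                while j + 3 <= len(transcript) and transcript[j : j + 3] not in STOP_CODONS:
                    j += 3
                if j + 3 <= len(transcript):  # found an in-frame stop codon
                    orf = transcript[i : j + 3]
                    if len(orf) > len(best):
                        best = orf
                    i = j + 3
                    continue
            i += 3
    return best
```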
Finally, train models on these embeddings to predict whether a given ORF originated from a coding or noncoding transcript. Separate models are trained on, and used to make predictions for, each species. This results in a matrix of model performance metrics for all pairs of species (one used to train the model, the other to evaluate it). The --output-dirpath in the command below corresponds to the directories passed to the calc_metrics_from_smallesm_results function defined below. (This command was run manually with and without --max-length 100 to train models on all ORFs and on only sORFs, respectively.)
plmutils orf-prediction train-and-evaluate \
--coding-dirpath output/data/processed/final/coding-dedup-ssx1/embeddings/esm2_t6_8M_UR50D \
--noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/embeddings/esm2_t6_8M_UR50D \
--output-dirpath output/data/esm-model-results-ssx1-all
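Conceptually, train-and-evaluate fits one classifier per training species and scores it against every species, which is what produces the square metrics matrices plotted below. A schematic version (random stand-in "embeddings" and a plain logistic-regression head; the actual plmutils model and features are not shown in this notebook):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
species = ["hsap", "scer", "cele"]

# hypothetical stand-ins for per-species ORF embeddings and coding labels
data = {sp: (rng.normal(size=(200, 32)), rng.integers(0, 2, size=200)) for sp in species}

rows = []
for train_sp in species:
    X_train, y_train = data[train_sp]
    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    for test_sp in species:
        X_test, y_test = data[test_sp]
        rows.append(
            {
                "training_species_id": train_sp,
                "test_species_id": test_sp,
                "accuracy": model.score(X_test, y_test),
            }
        )

# `rows` holds one record per (training, test) species pair: a 3x3 matrix here
```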
def calc_metrics_from_smallesm_results(results_dirpath, max_length=None):
    """
    Calculate classification metrics from ESM-based model results.
    """
    all_metrics = []
    prediction_filepaths = pathlib.Path(results_dirpath).glob('*.csv')
    for prediction_filepath in prediction_filepaths:
        df = pd.read_csv(prediction_filepath)
        if max_length is not None:
            df = df.loc[df.sequence_length < max_length]
        metrics = calc_metrics(
            y_true=(df.true_label == 'coding'),
            y_pred_proba=df.predicted_probability.values,
        )
        metrics['training_species_id'] = df.iloc[0].training_species_id
        metrics['test_species_id'] = df.iloc[0].testing_species_id
        metrics['num_coding'] = (df.true_label == 'coding').sum()
        metrics['num_noncoding'] = (df.true_label != 'coding').sum()
        all_metrics.append(metrics)
    df = pd.DataFrame(all_metrics)
    df['true_negative_rate'] = df.num_true_negative / df.num_noncoding
    return df
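calc_metrics itself is imported from plmutils.models and is not defined in this notebook. Judging from the columns of the metrics tables, it computes roughly the following (an approximate sketch built on scikit-learn, not the actual implementation; the 0.5 decision threshold is an assumption):

```python
import numpy as np
from sklearn import metrics as skm

def calc_metrics_sketch(y_true, y_pred_proba, threshold=0.5):
    """Approximate the metrics dict that `plmutils.models.calc_metrics`
    appears to return, based on the columns of the results tables."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred_proba) > threshold
    return {
        "auc_roc": skm.roc_auc_score(y_true, y_pred_proba),
        "accuracy": skm.accuracy_score(y_true, y_pred),
        "precision": skm.precision_score(y_true, y_pred, zero_division=0),
        "recall": skm.recall_score(y_true, y_pred),
        "mcc": skm.matthews_corrcoef(y_true, y_pred),
        "num_true_positive": int((y_true & y_pred).sum()),
        "num_false_positive": int((~y_true & y_pred).sum()),
        "num_true_negative": int((~y_true & ~y_pred).sum()),
        "num_false_negative": int((y_true & ~y_pred).sum()),
    }
```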
metrics_esm_trained_all_eval_all = calc_metrics_from_smallesm_results(
    '../output/results/2024-03-01-esm-model-results-ssx1-all/',
    max_length=None,
)
metrics_esm_trained_all_eval_short = calc_metrics_from_smallesm_results(
    '../output/results/2024-03-01-esm-model-results-ssx1-all/',
    max_length=100,
)
metrics_esm_trained_short_eval_all = calc_metrics_from_smallesm_results(
    '../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/',
    max_length=None,
)
metrics_esm_trained_short_eval_short = calc_metrics_from_smallesm_results(
    '../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/',
    max_length=100,
)
metrics_esm_trained_all_eval_all.head()
| auc_roc | accuracy | precision | recall | mcc | num_true_positive | num_false_positive | num_true_negative | num_false_negative | num_positive | num_negative | training_species_id | test_species_id | num_coding | num_noncoding | true_negative_rate | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.803246 | 0.820781 | 0.818813 | 0.998590 | 0.246917 | 18411 | 4074 | 366 | 26 | 18437 | 4440 | ddis | rnor | 18437 | 4440 | 0.082432 |
| 1 | 0.957956 | 0.949619 | 0.969926 | 0.970958 | 0.799962 | 15513 | 481 | 2299 | 464 | 15977 | 2780 | amel | dmel | 15977 | 2780 | 0.826978 |
| 2 | 0.825174 | 0.859946 | 0.858987 | 0.999687 | 0.216374 | 15972 | 2622 | 158 | 5 | 15977 | 2780 | oind | dmel | 15977 | 2780 | 0.056835 |
| 3 | 0.861818 | 0.500659 | 1.000000 | 0.500417 | 0.021989 | 11393 | 0 | 11 | 11374 | 22767 | 11 | rnor | tthe | 22767 | 11 | 1.000000 |
| 4 | 0.986828 | 0.956949 | 0.972110 | 0.973813 | 0.867413 | 19449 | 558 | 4580 | 523 | 19972 | 5138 | rnor | cele | 19972 | 5138 | 0.891397 |
Compare ESM-based models trained on all ORFs and only sORFs¶
# models trained on either all ORFs or only sORFs and evaluated on only sORFs.
plot_heatmaps(
    metrics_esm_trained_all_eval_short,
    metrics_esm_trained_short_eval_short,
    column='mcc',
    model_names=('ESM-based (trained all, eval short)', 'ESM-based (trained short, eval short)'),
)

# models trained only on sORFs and evaluated on all or only sORFs.
plot_heatmaps(
    metrics_esm_trained_short_eval_all,
    metrics_esm_trained_short_eval_short,
    column='mcc',
    model_names=('ESM-based (trained short, eval all)', 'ESM-based (trained short, eval short)'),
)

# models trained on all ORFs or only sORFs, but evaluated on all sequences.
plot_heatmaps(
    metrics_esm_trained_all_eval_all,
    metrics_esm_trained_short_eval_all,
    column='mcc',
    model_names=('ESM-based (trained all, eval all)', 'ESM-based (trained short, eval all)'),
)
RNASamba predictions¶
These predictions were generated by the script plm-utils/scripts/rnasamba/train_and_evaluate.py using the same datasets of deduplicated coding and noncoding transcripts generated by the plmutils orf-prediction construct-data command described above.
To train RNASamba models on all sequences:
python scripts/rnasamba-comparison/train_and_evaluate.py \
--coding-dirpath output/data/processed/final/coding-dedup-ssx1/transcripts \
--noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/transcripts \
--output-dirpath 2024-02-28-rnasamba-results-ssx1-all
To train RNASamba models on transcripts corresponding to sORFs:
python scripts/rnasamba-comparison/train_and_evaluate.py \
--coding-dirpath output/data/processed/final/coding-dedup-ssx1/transcripts \
--noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/transcripts \
--output-dirpath output/data/2024-02-28-rnasamba-results-ssx1-min-peptide-length-100 \
--max-length 100
The --output-dirpath above corresponds to the directory passed to the calc_metrics_from_rnasamba_results function below.
def calc_metrics_from_rnasamba_results(rnasamba_results_dirpath):
    """
    Aggregate the results from RNASamba models trained in the script
    `scripts/rnasamba-comparison/train_and_evaluate.py`.
    """
    all_metrics = []
    dirpaths = [p for p in rnasamba_results_dirpath.glob('trained-on*') if p.is_dir()]
    for dirpath in dirpaths:
        # dirnames are of the form 'trained-on-{species_id}-filtered'.
        training_species_id = dirpath.stem.split('-')[2]
        prediction_filepaths = dirpath.glob('*.tsv')
        for prediction_filepath in prediction_filepaths:
            # filenames are of the form '{species_id}-preds.tsv'
            # (the files are comma-separated despite the extension).
            test_species_id = prediction_filepath.stem.split('-')[0]
            df = pd.read_csv(prediction_filepath, sep=',')
            metrics = calc_metrics(
                y_true=(df.true_label == 'coding'), y_pred_proba=df.coding_score.values
            )
            metrics['training_species_id'] = training_species_id
            metrics['test_species_id'] = test_species_id
            metrics['num_coding'] = (df.true_label == 'coding').sum()
            metrics['num_noncoding'] = (df.true_label != 'coding').sum()
            all_metrics.append(metrics)
    df = pd.DataFrame(all_metrics)
    df['true_negative_rate'] = df.num_true_negative / df.num_noncoding
    return df
# models trained and tested on all transcripts.
rnasamba_results_dirpath_all = pathlib.Path(
    '../output/results/2024-02-23-rnasamba-models-clustered-ssx3/'
)

# models trained and tested only on transcripts whose longest ORFs are sORFs.
rnasamba_results_dirpath_short = pathlib.Path(
    '../output/results/2024-02-28-rnasamba-results-ssx1-max-peptide-length-100/'
)
metrics_rs_trained_all_eval_all = calc_metrics_from_rnasamba_results(rnasamba_results_dirpath_all)
metrics_rs_trained_short_eval_short = calc_metrics_from_rnasamba_results(rnasamba_results_dirpath_short)
/home/keith/miniforge3/envs/esm-py311-env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1497: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Compare RNASamba models trained on all or only sORFs¶
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_rs_trained_short_eval_short,
    column='mcc',
    model_names=('RNASamba (all)', 'RNASamba (short)'),
)
Compare RNASamba and ESM-based models¶
These are the most important plots in this notebook. They compare the performance of ESM-based models to RNASamba models by plotting the heatmap of performance metrics side by side.
Models trained and evaluated on all transcripts (for RNASamba) or ORFs (for ESM-based)¶
# overall performance (MCC metric)
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column='mcc',
    model_names=('RNASamba (all)', 'ESM-based (all)'),
)

# recall (also the true positive rate, or num_true_positive / num_coding)
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column='recall',
    model_names=('RNASamba (all)', 'ESM-based (all)'),
)

# the true negative rate.
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column='true_negative_rate',
    model_names=('RNASamba (all)', 'ESM-based (all)'),
)
Models trained only on short sequences (< 100aa)¶
For RNASamba, this means the models were trained only on transcripts whose longest ORF was an sORF (less than 100aa long).
Note that the class imbalance in this case is severe: most species have few coding transcripts whose longest ORF is an sORF. This likely at least partly explains why the RNASamba models perform so poorly, as we do not compensate for the class imbalance when training them (whereas we do when training the ESM-based models).
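For reference, a standard way to compensate for class imbalance in a classifier is to reweight the classes inversely to their frequency; whether plmutils uses exactly this mechanism is not shown here, but with scikit-learn it looks like:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight

# hypothetical 19:1 imbalance between noncoding (0) and coding (1) examples
y = np.array([0] * 95 + [1] * 5)
X = np.random.default_rng(0).normal(size=(100, 8))

# 'balanced' weights each class by n_samples / (n_classes * n_class_samples),
# so the rare class contributes as much total loss as the common one.
weights = compute_class_weight("balanced", classes=np.array([0, 1]), y=y)

model = LogisticRegression(class_weight="balanced", max_iter=1000).fit(X, y)
```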
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column='mcc',
    model_names=('RNASamba (short)', 'ESM-based (short)'),
)

plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column='recall',
    model_names=('RNASamba (short)', 'ESM-based (short)'),
)

plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column='true_negative_rate',
    model_names=('RNASamba (short)', 'ESM-based (short)'),
)
Aside: blasting against peptipedia¶
We were curious whether some of the false positives from the ESM-based models represent genuine sORFs from lncRNAs (which are annotated as noncoding). To examine this, we blasted all of the putative ORFs against peptipedia and plotted the distribution of minimum evalues for the putative sORFs on which the ESM-based model made either true or false positive predictions. If the model correctly identifies genuine sORFs from lncRNAs, we'd expect to see an enrichment of low evalues among the false positives.
The command plmutils orf-classification blast-peptipedia was used to generate the directory of blast results that are loaded and concatenated by concat_smallesm_results function below.
def concat_smallesm_results(results_dirpath):
    """
    Load and concatenate the predictions from the ESM-based models.
    """
    dfs = []
    prediction_filepaths = pathlib.Path(results_dirpath).glob('*.csv')
    for prediction_filepath in prediction_filepaths:
        dfs.append(pd.read_csv(prediction_filepath))
    return pd.concat(dfs)
# predictions from models trained on all putative ORFs.
esm_trained_all_preds = concat_smallesm_results(
    '../output/results/2024-03-01-esm-model-results-ssx1-all/'
)

# predictions from models trained on short peptides (< 100aa).
esm_trained_short_preds = concat_smallesm_results(
    '../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/'
)
esm_trained_all_preds.shape, esm_trained_short_preds.shape
((7766064, 6), (7766064, 6))
esm_trained_short_preds.head()
| sequence_id | sequence_length | true_label | predicted_probability | training_species_id | testing_species_id | |
|---|---|---|---|---|---|---|
| 0 | RNOR.ENSRNOT00000105380.1 | 124 | coding | 0.306121 | ddis | rnor |
| 1 | RNOR.ENSRNOT00000094775.1 | 265 | coding | 0.436046 | ddis | rnor |
| 2 | RNOR.ENSRNOT00000119508.1 | 594 | coding | 0.868186 | ddis | rnor |
| 3 | RNOR.ENSRNOT00000094997.1 | 690 | coding | 0.568402 | ddis | rnor |
| 4 | RNOR.ENSRNOT00000119131.1 | 390 | coding | 0.739940 | ddis | rnor |
# count the number of peptides from coding and noncoding transcripts to make sure
# that the class imbalance between coding and noncoding is not too severe.
# (we only need to look at preds from one model, since each model is tested with all species).
hsap_preds = esm_trained_all_preds.loc[esm_trained_all_preds.training_species_id == 'hsap'].copy()
pd.merge(
    hsap_preds.groupby(['testing_species_id', 'true_label']).count().sequence_id,
    (
        hsap_preds.loc[hsap_preds.sequence_length < 100]
        .groupby(['testing_species_id', 'true_label'])
        .count()
        .sequence_id
    ),
    left_index=True,
    right_index=True,
    suffixes=('_all', '_short'),
)
| sequence_id_all | sequence_id_short | ||
|---|---|---|---|
| testing_species_id | true_label | ||
| amel | coding | 11725 | 223 |
| noncoding | 2429 | 1204 | |
| atha | coding | 30734 | 1800 |
| noncoding | 3638 | 3164 | |
| cele | coding | 19972 | 1285 |
| noncoding | 5138 | 5051 | |
| ddis | coding | 11688 | 1220 |
| noncoding | 28 | 28 | |
| dmel | coding | 15977 | 538 |
| noncoding | 2780 | 1617 | |
| drer | coding | 28385 | 2088 |
| noncoding | 3674 | 2343 | |
| ggal | coding | 24448 | 375 |
| noncoding | 19831 | 8242 | |
| hsap | coding | 55672 | 14928 |
| noncoding | 42157 | 23212 | |
| mmus | coding | 32128 | 8062 |
| noncoding | 19171 | 11011 | |
| oind | coding | 28134 | 2515 |
| noncoding | 205 | 200 | |
| rnor | coding | 18437 | 556 |
| noncoding | 4440 | 2840 | |
| scer | coding | 6013 | 368 |
| noncoding | 103 | 99 | |
| spom | coding | 4683 | 178 |
| noncoding | 1032 | 668 | |
| tthe | coding | 22767 | 19158 |
| noncoding | 11 | 11 | |
| xtro | coding | 27543 | 572 |
| noncoding | 244 | 243 | |
| zmay | coding | 39519 | 1002 |
| noncoding | 2673 | 744 |
def concat_blast_results(dirpaths):
    """
    Aggregate the blast results generated by `plmutils orf-classification blast-peptipedia`.
    """
    blast_results_columns = (
        "qseqid sseqid full_sseq pident length qlen slen mismatch gapopen qstart qend sstart send evalue bitscore"
    ).split(' ')
    dfs = []
    for dirpath in dirpaths:
        filepaths = pathlib.Path(dirpath).glob('*.tsv')
        for filepath in filepaths:
            try:
                # blast tabular output has no header row.
                df = pd.read_csv(filepath, sep='\t', header=None, names=blast_results_columns)
            except pd.errors.EmptyDataError:
                # skip empty results files (queries with no hits).
                continue
            dfs.append(df)
    return pd.concat(dfs)
blast_results = concat_blast_results(
    [
        '../output/data/processed/final/coding-dedup-ssx1/blast-peptipedia-results/',
        '../output/data/processed/final/noncoding-dedup-ssx1/blast-peptipedia-results/',
    ]
)
# use the log of the evalue for readability.
blast_results['evalue'] = np.log(blast_results.evalue)
# we only need to examine the minimum evalue for all hits to each peptide.
min_evalues = blast_results.groupby('qseqid').evalue.min().reset_index()
# merge the minimum evalues with the model predictions.
# (the inner merge keeps only peptides with at least one peptipedia hit.)
esm_trained_short_preds_w_evalues = pd.merge(
    esm_trained_short_preds, min_evalues, left_on='sequence_id', right_on='qseqid', how='inner'
)
esm_trained_all_preds_w_evalues = pd.merge(
    esm_trained_all_preds, min_evalues, left_on='sequence_id', right_on='qseqid', how='inner'
)
esm_trained_short_preds_w_evalues_short_only = esm_trained_short_preds_w_evalues.loc[
    esm_trained_short_preds_w_evalues.sequence_length < 100
].copy()
# sanity-check: count the number of peptides that had hits in peptipedia.
(
    esm_trained_short_preds_w_evalues_short_only
    # we only need to look at one model
    .loc[esm_trained_short_preds_w_evalues_short_only.training_species_id == 'hsap']
    .groupby(['testing_species_id', 'true_label'])
    .count()[['sequence_id']]
)
| sequence_id | ||
|---|---|---|
| testing_species_id | true_label | |
| amel | coding | 108 |
| noncoding | 10 | |
| atha | coding | 534 |
| noncoding | 17 | |
| cele | coding | 217 |
| noncoding | 4 | |
| ddis | coding | 366 |
| dmel | coding | 121 |
| noncoding | 12 | |
| drer | coding | 273 |
| noncoding | 46 | |
| ggal | coding | 142 |
| noncoding | 5 | |
| hsap | coding | 2470 |
| noncoding | 1308 | |
| mmus | coding | 996 |
| noncoding | 119 | |
| oind | coding | 198 |
| noncoding | 1 | |
| rnor | coding | 257 |
| noncoding | 14 | |
| scer | coding | 339 |
| noncoding | 1 | |
| spom | coding | 144 |
| noncoding | 23 | |
| tthe | coding | 400 |
| xtro | coding | 187 |
| noncoding | 1 | |
| zmay | coding | 113 |
| noncoding | 4 |
Histograms of evalues for coding and noncoding transcripts¶
This was to determine whether the false positives were enriched for peptides that had hits in peptipedia, which would suggest that they correspond to genuine sORFs from lncRNAs (and are therefore not actually false positives).
# we only look at preds for short peptides from the human dataset,
# because it is one of the only species with a decent number of short peptides
# from noncoding transcripts that also have peptipedia hits.
preds = esm_trained_all_preds_w_evalues.loc[
    (esm_trained_all_preds_w_evalues.training_species_id == 'hsap')
    & (esm_trained_all_preds_w_evalues.testing_species_id == 'hsap')
    & (esm_trained_all_preds_w_evalues.sequence_length < 100)
]
fig, axs = plt.subplots(1, 2, figsize=(16, 6))

min_min_evalue = -150
bins = np.arange(min_min_evalue, 0, -min_min_evalue / 30)
kwargs = dict(bins=bins, density=False, alpha=0.5)

# left axis: coding transcripts
ax = axs[0]
ax.hist(
    preds[(preds.true_label == 'coding') & (preds.predicted_probability > 0.5)].evalue,
    label='True positives',
    color='blue',
    **kwargs,
)
ax.hist(
    preds[(preds.true_label == 'coding') & (preds.predicted_probability < 0.5)].evalue,
    label='False negatives',
    color='red',
    **kwargs,
)
ax.legend()
ax.set_xlabel('Minimum log evalue')
ax.set_ylabel('Count')
ax.set_title('Coding transcripts')

# right axis: noncoding transcripts
ax = axs[1]
ax.hist(
    preds[(preds.true_label == 'noncoding') & (preds.predicted_probability < 0.5)].evalue,
    label='True negatives',
    color='blue',
    **kwargs,
)
_ = ax.hist(
    preds[(preds.true_label == 'noncoding') & (preds.predicted_probability > 0.5)].evalue,
    label='False positives',
    color='red',
    **kwargs,
)
ax.legend()
ax.set_title('Noncoding transcripts')
Text(0.5, 1.0, 'Noncoding transcripts')